I downloaded satellite images of clouds from NASA Worldview.

I chose three regions, each spanning 21 degrees of longitude and 14 degrees of latitude. The true-color images were taken by two polar-orbiting satellites, TERRA and AQUA, each of which passes over a given region once a day. Because the imager on these satellites has a small footprint, a single image may be stitched together from two successive orbits. Any area that neither orbit covered is marked black.

The labels were created by a team of 68 scientists at the Max-Planck-Institute for Meteorology.
# The goal is a model that classifies cloud organization patterns, which may help scientists understand how clouds will affect our future climate.
In [94]:
!pip install catalyst --upgrade
!pip install --user git+https://github.com/albu/albumentations@bdd6a4e
!pip install --user git+https://github.com/qubvel/segmentation_models.pytorch
Requirement already up-to-date: catalyst in c:\users\shane\lib\site-packages (20.1.3)
Requirement already satisfied, skipping upgrade: torch>=1.0.0 in c:\users\shane\lib\site-packages (from catalyst) (1.4.0)
Successfully installed albumentations-0.2.2 opencv-python-headless-4.1.2.30
Requirement already satisfied (use --upgrade to upgrade): segmentation-models-pytorch==0.1.0 from git+https://github.com/qubvel/segmentation_models.pytorch in c:\users\shane\lib\site-packages
Successfully built segmentation-models-pytorch
In [96]:
!pip install git+https://github.com/qubvel/segmentation_models.pytorch
In [1]:
import os
import cv2
import collections
import time 
import tqdm
from PIL import Image
from functools import partial
train_on_gpu = True
In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
In [90]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

import torchvision
import torchvision.transforms as transforms
import torch
from torch.utils.data import TensorDataset, DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data.sampler import SubsetRandomSampler
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, CosineAnnealingLR



import segmentation_models_pytorch as smp
model = smp.Unet()
Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to C:\Users\Shane/.cache\torch\checkpoints\resnet34-333f7ec4.pth

In [68]:
from albumentations import Compose, RandomCrop, Normalize, HorizontalFlip, Resize
from albumentations.pytorch import ToTensor
import albumentations as albu
from albumentations import pytorch as AT
In [148]:
from catalyst.dl.utils import criterion
In [151]:
from catalyst.data import Augmentor
from catalyst.dl import utils
from catalyst.data.reader import ImageReader, ScalarReader, ReaderCompose, LambdaReader
from catalyst.dl.runner import SupervisedRunner
#from catalyst.contrib.models.segmentation import Unet
from catalyst.dl.callbacks import DiceCallback, EarlyStoppingCallback, InferCallback, CheckpointCallback
In [44]:
def get_img(x, folder: str='train_images'):
    """
    Return image based on image name and folder.
    """
    data_folder = f"{path}/{folder}"
    image_path = os.path.join(data_folder, x)
    img = cv2.imread(image_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img


def rle_decode(mask_rle: str = '', shape: tuple = (1400, 2100)):
    '''
    Decode rle encoded mask.
    
    :param mask_rle: run-length as string formatted (start length)
    :param shape: (height, width) of array to return 
    Returns numpy array, 1 - mask, 0 - background
    '''
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape, order='F')


def make_mask(df: pd.DataFrame, image_name: str='img.jpg', shape: tuple = (1400, 2100)):
    """
    Create mask based on df, image name and shape.
    """
    encoded_masks = df.loc[df['im_id'] == image_name, 'EncodedPixels']
    masks = np.zeros((shape[0], shape[1], 4), dtype=np.float32)

    for idx, label in enumerate(encoded_masks.values):
        # a missing label is NaN (a float); only decode real RLE strings
        if isinstance(label, str):
            mask = rle_decode(label, shape)
            masks[:, :, idx] = mask
            
    return masks


def to_tensor(x, **kwargs):
    """
    Convert image or mask.
    """
    return x.transpose(2, 0, 1).astype('float32')


def mask2rle(img):
    '''
    Convert mask to rle.
    img: numpy array, 1 - mask, 0 - background
    Returns run length as string formatted
    '''
    pixels= img.T.flatten()
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)


def visualize(image, mask, original_image=None, original_mask=None):
    """
    Plot image and masks.
    If two pairs of images and masks are passed, show both.
    """
    fontsize = 14
    class_dict = {0: 'Fish', 1: 'Flower', 2: 'Gravel', 3: 'Sugar'}
    
    if original_image is None and original_mask is None:
        f, ax = plt.subplots(1, 5, figsize=(24, 24))

        ax[0].imshow(image)
        for i in range(4):
            ax[i + 1].imshow(mask[:, :, i])
            ax[i + 1].set_title(f'Mask {class_dict[i]}', fontsize=fontsize)
    else:
        f, ax = plt.subplots(2, 5, figsize=(24, 12))

        ax[0, 0].imshow(original_image)
        ax[0, 0].set_title('Original image', fontsize=fontsize)
                
        for i in range(4):
            ax[0, i + 1].imshow(original_mask[:, :, i])
            ax[0, i + 1].set_title(f'Original mask {class_dict[i]}', fontsize=fontsize)
        
        ax[1, 0].imshow(image)
        ax[1, 0].set_title('Transformed image', fontsize=fontsize)
        
        
        for i in range(4):
            ax[1, i + 1].imshow(mask[:, :, i])
            ax[1, i + 1].set_title(f'Transformed mask {class_dict[i]}', fontsize=fontsize)
            
            
def visualize_with_raw(image, mask, original_image=None, original_mask=None, raw_image=None, raw_mask=None):
    """
    Plot three rows: the original image with its ground-truth masks,
    the raw predicted masks, and the predicted masks after post-processing.
    """
    fontsize = 14
    class_dict = {0: 'Fish', 1: 'Flower', 2: 'Gravel', 3: 'Sugar'}

    f, ax = plt.subplots(3, 5, figsize=(24, 12))

    ax[0, 0].imshow(original_image)
    ax[0, 0].set_title('Original image', fontsize=fontsize)

    for i in range(4):
        ax[0, i + 1].imshow(original_mask[:, :, i])
        ax[0, i + 1].set_title(f'Original mask {class_dict[i]}', fontsize=fontsize)


    ax[1, 0].imshow(raw_image)
    ax[1, 0].set_title('Original image', fontsize=fontsize)

    for i in range(4):
        ax[1, i + 1].imshow(raw_mask[:, :, i])
        ax[1, i + 1].set_title(f'Raw predicted mask {class_dict[i]}', fontsize=fontsize)
        
    ax[2, 0].imshow(image)
    ax[2, 0].set_title('Transformed image', fontsize=fontsize)


    for i in range(4):
        ax[2, i + 1].imshow(mask[:, :, i])
        ax[2, i + 1].set_title(f'Predicted mask with processing {class_dict[i]}', fontsize=fontsize)
            
            
def plot_with_augmentation(image, mask, augment):
    """
    Wrapper for `visualize` function.
    """
    augmented = augment(image=image, mask=mask)
    image_flipped = augmented['image']
    mask_flipped = augmented['mask']
    visualize(image_flipped, mask_flipped, original_image=image, original_mask=mask)
    
    
sigmoid = lambda x: 1 / (1 + np.exp(-x))


def post_process(probability, threshold, min_size):
    """
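    Post-process one predicted probability mask: threshold it, then drop
    connected components with `min_size` pixels or fewer.
    Returns the cleaned binary mask and the number of components kept.
    """
    # binarize the probabilities, then label the connected regions
    mask = cv2.threshold(probability, threshold, 1, cv2.THRESH_BINARY)[1]
    num_component, component = cv2.connectedComponents(mask.astype(np.uint8))
    # match the input shape instead of hard-coding (350, 525)
    predictions = np.zeros(probability.shape, np.float32)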
    num = 0
    for c in range(1, num_component):
        p = (component == c)
        if p.sum() > min_size:
            predictions[p] = 1
            num += 1
    return predictions, num


def get_training_augmentation():
    train_transform = [

        albu.HorizontalFlip(p=0.5),
        albu.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=0.5, border_mode=0),
        albu.GridDistortion(p=0.5),
        albu.OpticalDistortion(p=0.5, distort_limit=2, shift_limit=0.5),
        albu.Resize(320, 640)
    ]
    return albu.Compose(train_transform)


def get_validation_augmentation():
    """Resize to 320x640; both sides are divisible by 32, as the U-Net encoder requires."""
    test_transform = [
        albu.Resize(320, 640)
    ]
    return albu.Compose(test_transform)


def get_preprocessing(preprocessing_fn):
    """Construct preprocessing transform
    
    Args:
        preprocessing_fn (callable): data normalization function 
            (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose
    
    """
    
    _transform = [
        albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=to_tensor, mask=to_tensor),
    ]
    return albu.Compose(_transform)


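def dice(img1, img2):
    """Dice coefficient between two binary masks (1.0 when both are empty)."""
    img1 = np.asarray(img1).astype(bool)  # np.bool is deprecated/removed in newer NumPy
    img2 = np.asarray(img2).astype(bool)

    total = img1.sum() + img2.sum()
    if total == 0:  # avoid 0/0 when both masks are empty
        return 1.0

    intersection = np.logical_and(img1, img2)

    return 2. * intersection.sum() / total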
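The RLE strings in train.csv use 1-indexed start positions and count pixels column by column (top to bottom, then left to right). That is why `rle_decode` subtracts 1 from the starts and reshapes with `order='F'`, while `mask2rle` flattens the transposed mask. A round trip on a tiny mask is a cheap way to confirm that the two helpers agree. The sketch below re-declares both helpers (same logic, with type hints) so it runs on its own; the 3x4 toy mask is made up for illustration.

```python
import numpy as np


def rle_decode(mask_rle: str = '', shape: tuple = (1400, 2100)) -> np.ndarray:
    """Decode a 1-indexed, column-major RLE string into a 0/1 mask."""
    s = mask_rle.split()
    starts = np.asarray(s[0::2], dtype=int) - 1  # RLE starts are 1-indexed
    lengths = np.asarray(s[1::2], dtype=int)
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, starts + lengths):
        flat[lo:hi] = 1
    return flat.reshape(shape, order='F')  # fill column by column


def mask2rle(img: np.ndarray) -> str:
    """Encode a 0/1 mask as a 1-indexed, column-major RLE string."""
    pixels = img.T.flatten()  # transpose + row-major flatten = column-major order
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1  # 1-indexed run boundaries
    runs[1::2] -= runs[::2]  # convert run ends into run lengths
    return ' '.join(str(x) for x in runs)


# Toy 3x4 mask: all of column 1 plus the bottom pixel of column 2.
toy_mask = np.array([[0, 1, 0, 0],
                     [0, 1, 0, 0],
                     [0, 1, 1, 0]], dtype=np.uint8)

encoded = mask2rle(toy_mask)
print(encoded)  # 4 3 9 1
assert np.array_equal(rle_decode(encoded, toy_mask.shape), toy_mask)
```

"4 3 9 1" reads as "3 pixels starting at position 4 (the whole second column), then 1 pixel at position 9 (bottom of the third column)".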
Data Overview
In [45]:
path = '../cosmology'
os.listdir(path)
Out[45]:
['.ipynb_checkpoints',
 'BTC Mk1.ipynb',
 'crypto_tradinds.csv',
 'cumulative.csv',
 'exoTest.csv',
 'exoTrain.csv',
 'historical-data-on-the-trading-of-cryptocurrencies.zip',
 'New folder',
 'sample_submission.csv',
 'test_images',
 'train.csv',
 'train_images',
 'Untitled.ipynb',
 'Untitled1.ipynb']
In [46]:
train = pd.read_csv(f'{path}/train.csv')
sub = pd.read_csv(f'{path}/sample_submission.csv')
In [47]:
train.head()
Out[47]:
   Image_Label         EncodedPixels
0  0011165.jpg_Fish    264918 937 266318 937 267718 937 269118 937 27...
1  0011165.jpg_Flower  1355565 1002 1356965 1002 1358365 1002 1359765...
2  0011165.jpg_Gravel  NaN
3  0011165.jpg_Sugar   NaN
4  002be4f.jpg_Fish    233813 878 235213 878 236613 878 238010 881 23...
In [48]:
n_train = len(os.listdir(f'{path}/train_images'))
n_test = len(os.listdir(f'{path}/test_images'))
print(f'There are {n_train} images in train dataset')
print(f'There are {n_test} images in test dataset')
There are 5546 images in train dataset
There are 3698 images in test dataset
In [49]:
train['Image_Label'].apply(lambda x: x.split('_')[1]).value_counts()
Out[49]:
Sugar     5546
Flower    5546
Gravel    5546
Fish      5546
Name: Image_Label, dtype: int64
Each of the 5546 training images has one row per class (Fish, Flower, Gravel, Sugar), so an image can have up to 4 masks.
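Because every image contributes exactly one row per class, `Image_Label` is a compound key of the form `<image file>_<class>`. The sketch below splits that key with `str.rsplit('_', 1)` and regroups the rows into one record per image. Splitting on the last underscore is my choice; it stays correct even if a file name contained an underscore. The rows are abridged from the `train.head()` output above, with the RLE strings shortened and `None` standing in for NaN.

```python
from collections import defaultdict

# Abridged rows in the train.csv layout: (Image_Label, EncodedPixels or None for NaN).
rows = [
    ("0011165.jpg_Fish", "264918 937 266318 937"),
    ("0011165.jpg_Flower", "1355565 1002 1356965 1002"),
    ("0011165.jpg_Gravel", None),
    ("0011165.jpg_Sugar", None),
]

masks_by_image = defaultdict(dict)
for image_label, encoded in rows:
    # Split on the last underscore: "<image id>_<class>".
    im_id, label = image_label.rsplit("_", 1)
    masks_by_image[im_id][label] = encoded

for im_id, labels in masks_by_image.items():
    present = [label for label, rle in labels.items() if rle is not None]
    print(im_id, "has masks for", present)
```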
In [50]:
train.loc[train['EncodedPixels'].isnull() == False, 'Image_Label'].apply(lambda x: x.split('_')[1]).value_counts()
Out[50]:
Sugar     3751
Gravel    2939
Fish      2781
Flower    2365
Name: Image_Label, dtype: int64
In [51]:
train.loc[train['EncodedPixels'].isnull() == False, 'Image_Label'].apply(lambda x: x.split('_')[0]).value_counts().value_counts()
Out[51]:
2    2372
3    1560
1    1348
4     266
Name: Image_Label, dtype: int64
Many masks are empty: most images contain only two or three of the four patterns, and only 266 images have all four.
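The chained `value_counts().value_counts()` call above is compact but dense. The first call counts the non-empty masks per image, and the second counts how many images share each mask count. The plain-Python sketch below spells out the same two-step tally on a hypothetical toy set of three images.

```python
from collections import Counter

# Hypothetical toy data: image id -> RLE per class (None means the mask is empty).
masks_by_image = {
    "a.jpg": {"Fish": "1 5", "Flower": None, "Gravel": "20 3", "Sugar": None},
    "b.jpg": {"Fish": None, "Flower": None, "Gravel": None, "Sugar": "7 2"},
    "c.jpg": {"Fish": "4 4", "Flower": "9 1", "Gravel": None, "Sugar": None},
}

# Step 1: number of non-empty masks per image.
masks_per_image = {
    im_id: sum(rle is not None for rle in labels.values())
    for im_id, labels in masks_by_image.items()
}

# Step 2: how many images have 1, 2, 3 or 4 masks.
distribution = Counter(masks_per_image.values())
print(sorted(distribution.items()))  # [(1, 1), (2, 2)]
```

One difference: the pandas version filters out null rows first, so an image with no masks disappears entirely. Here such an image would show up with a count of 0.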
In [52]:
train['label'] = train['Image_Label'].apply(lambda x: x.split('_')[1])
train['im_id'] = train['Image_Label'].apply(lambda x: x.split('_')[0])


sub['label'] = sub['Image_Label'].apply(lambda x: x.split('_')[1])
sub['im_id'] = sub['Image_Label'].apply(lambda x: x.split('_')[0])
Let's look at some images and their masks.
In [53]:
fig = plt.figure(figsize=(25, 16))
for j, im_id in enumerate(np.random.choice(train['im_id'].unique(), 4)):
    for i, (idx, row) in enumerate(train.loc[train['im_id'] == im_id].iterrows()):
        ax = fig.add_subplot(5, 4, j * 4 + i + 1, xticks=[], yticks=[])
        im = Image.open(f"{path}/train_images/{row['Image_Label'].split('_')[0]}")
        plt.imshow(im)
        mask_rle = row['EncodedPixels']
        if isinstance(mask_rle, str):  # the label might be missing (NaN)
            mask = rle_decode(mask_rle)
        else:
            mask = np.zeros((1400, 2100))
        plt.imshow(mask, alpha=0.5, cmap='gray')
        ax.set_title(f"Image: {row['Image_Label'].split('_')[0]}. Label: {row['label']}")
We see that some masks overlap, that they are often quite large, and that the masked regions resemble the patterns they are labeled with.
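To put a number on how much masks overlap, one option is to stack the four class masks into an `(H, W, 4)` array, as `make_mask` does. Then measure the fraction of labeled pixels that belong to more than one class. The helper `overlap_fraction` is hypothetical and not part of the original notebook; it is shown here on a synthetic 4x4 example.

```python
import numpy as np


def overlap_fraction(masks: np.ndarray) -> float:
    """Fraction of labeled pixels covered by two or more class masks.

    `masks` is an (H, W, C) array of 0/1 values, one channel per class.
    Returns 0.0 when no pixel is labeled at all.
    """
    per_pixel = (masks > 0).sum(axis=-1)  # number of classes covering each pixel
    labeled = per_pixel > 0
    if not labeled.any():
        return 0.0
    return float((per_pixel > 1).sum() / labeled.sum())


# Synthetic example: a 2x3 and a 2x2 rectangle that share one 2-pixel column.
demo_masks = np.zeros((4, 4, 4), dtype=np.float32)
demo_masks[0:2, 0:3, 0] = 1  # "Fish"
demo_masks[0:2, 2:4, 2] = 1  # "Gravel"
print(overlap_fraction(demo_masks))  # 0.25
```

Running this over the training masks would turn "some masks overlap" into a concrete statistic.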

Preparing data for Modeling

We create a list of unique image IDs together with the number of non-empty masks per image, then make a stratified train/validation split based on this count.
In [54]:
id_mask_count = train.loc[train['EncodedPixels'].isnull() == False, 'Image_Label'].apply(lambda x: x.split('_')[0]).value_counts().\
reset_index().rename(columns={'index': 'img_id', 'Image_Label': 'count'})
train_ids, valid_ids = train_test_split(id_mask_count['img_id'].values, random_state=42, stratify=id_mask_count['count'], test_size=0.1)
test_ids = sub['Image_Label'].apply(lambda x: x.split('_')[0]).drop_duplicates().values
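Stratifying on the mask count keeps the share of 1-, 2-, 3- and 4-mask images roughly equal in the training and validation sets. The idea can be sketched without scikit-learn by splitting each stratum separately. `stratified_split` is a hypothetical, simplified helper that rounds per stratum; it is not how `train_test_split` allocates samples internally.

```python
import numpy as np


def stratified_split(ids, strata, test_size=0.1, seed=42):
    """Split `ids` so each stratum sends about `test_size` of its members to the test side."""
    ids = np.asarray(ids)
    strata = np.asarray(strata)
    rng = np.random.default_rng(seed)
    train_parts, test_parts = [], []
    for value in np.unique(strata):
        members = rng.permutation(ids[strata == value])  # shuffle within the stratum
        n_test = int(round(len(members) * test_size))
        test_parts.append(members[:n_test])
        train_parts.append(members[n_test:])
    return np.concatenate(train_parts), np.concatenate(test_parts)


# Hypothetical ids whose mask counts come in a 5:3:2 ratio.
demo_ids = np.arange(100)
demo_counts = np.array([2] * 50 + [3] * 30 + [1] * 20)
demo_train, demo_valid = stratified_split(demo_ids, demo_counts, test_size=0.1)
print(len(demo_train), len(demo_valid))  # 90 10
```

The validation side ends up with 5 two-mask, 3 three-mask and 2 one-mask images, the same 5:3:2 ratio as the full set.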

Augmentations with Albumentations

Albumentations offers many useful augmentations; I picked a few of them more or less at random.
In [55]:
image_name = '8242ba0.jpg'
image = get_img(image_name)
mask = make_mask(train, image_name)
In [56]:
visualize(image, mask)

The original image is on the left and the four class masks are to its right. Next, let's try some augmentations.
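For segmentation, every geometric augmentation must be applied identically to the image and to its masks. That is why `plot_with_augmentation` passes both to albumentations in one call, `augment(image=image, mask=mask)`. The numpy sketch below mimics a horizontal flip by hand on a hypothetical random image. Applying the same transform to both arrays keeps each mask aligned with its pixels.

```python
import numpy as np


def hflip_pair(image: np.ndarray, mask: np.ndarray):
    """Flip an (H, W, C) image and its (H, W, K) masks left-right together."""
    return image[:, ::-1], mask[:, ::-1]


rng = np.random.default_rng(0)
demo_image = rng.integers(0, 256, size=(6, 8, 3), dtype=np.uint8)
# Toy "class": mark pixels whose red channel is bright, in mask channel 0.
demo_mask = np.zeros((6, 8, 4), dtype=np.float32)
demo_mask[..., 0] = demo_image[..., 0] > 200

flipped_image, flipped_mask = hflip_pair(demo_image, demo_mask)
# After flipping both, the mask still marks exactly the bright-red pixels.
assert np.array_equal(flipped_mask[..., 0] > 0, flipped_image[..., 0] > 200)
```

Flipping only the image would break that equality. This is the misalignment that joint augmentation avoids.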

In [57]:
plot_with_augmentation(image, mask, albu.HorizontalFlip(p=1))
In [58]:
plot_with_augmentation(image, mask, albu.VerticalFlip(p=1))
In [59]:
plot_with_augmentation(image, mask, albu.RandomRotate90(p=1))
In [60]:
plot_with_augmentation(image, mask, albu.ElasticTransform(p=1, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03))
In [63]:
plot_with_augmentation(image, mask, albu.GridDistortion(p=1))
In [64]:
plot_with_augmentation(image, mask, albu.OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5))

Setting up the training data for Catalyst

In [71]:
class CloudDataset(Dataset):
    def __init__(self, df: pd.DataFrame = None, datatype: str = 'train', img_ids: np.array = None,
                 transforms=albu.Compose([albu.HorizontalFlip(), AT.ToTensor()]),
                 preprocessing=None):
        self.df = df
        if datatype != 'test':
            self.data_folder = f"{path}/train_images"
        else:
            self.data_folder = f"{path}/test_images"
        self.img_ids = img_ids
        self.transforms = transforms
        self.preprocessing = preprocessing

    def __getitem__(self, idx):
        image_name = self.img_ids[idx]
        mask = make_mask(self.df, image_name)
        image_path = os.path.join(self.data_folder, image_name)
        img = cv2.imread(image_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        augmented = self.transforms(image=img, mask=mask)
        img = augmented['image']
        mask = augmented['mask']
        if self.preprocessing:
            preprocessed = self.preprocessing(image=img, mask=mask)
            img = preprocessed['image']
            mask = preprocessed['mask']
        return img, mask

    def __len__(self):
        return len(self.img_ids)
In [72]:
ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'
DEVICE = 'cuda'

# activation=None: the network outputs raw logits (no sigmoid), so the loss must accept logits
ACTIVATION = None
model = smp.Unet(
    encoder_name=ENCODER, 
    encoder_weights=ENCODER_WEIGHTS, 
    classes=4, 
    activation=ACTIVATION,
)
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
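get_preprocessing_fn returns the input normalization that matches the pretrained encoder. For ImageNet weights this is, roughly, scaling pixels to [0, 1] and standardizing each RGB channel with the ImageNet mean and standard deviation. The sketch below assumes those standard constants and uses hypothetical helper names; check the encoder settings in segmentation_models.pytorch for the exact values it applies.

```python
# Standard ImageNet channel statistics (RGB order), assumed here
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Scale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    return tuple((c / 255.0 - m) / s for c, m, s in zip(rgb, mean, std))


def normalize_image(image):
    """Apply normalize_pixel to every pixel of an H x W grid of RGB tuples."""
    return [[normalize_pixel(px) for px in row] for row in image]


black = normalize_pixel((0, 0, 0))        # roughly (-2.12, -2.04, -1.80)
white = normalize_pixel((255, 255, 255))  # roughly (2.25, 2.43, 2.64)
```

Matching the statistics the encoder saw during ImageNet pretraining is the reason to use the encoder's own preprocessing function rather than an arbitrary scaling.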
In [144]:
num_workers = 0
bs = 16
train_dataset = CloudDataset(df=train, datatype='train', img_ids=train_ids, transforms = get_training_augmentation(), preprocessing=get_preprocessing(preprocessing_fn))
valid_dataset = CloudDataset(df=train, datatype='valid', img_ids=valid_ids, transforms = get_validation_augmentation(), preprocessing=get_preprocessing(preprocessing_fn))

train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True, num_workers=num_workers)
valid_loader = DataLoader(valid_dataset, batch_size=bs, shuffle=False, num_workers=num_workers)

loaders = {
    "train": train_loader,
    "valid": valid_loader
}
In [158]:
num_epochs = 19
logdir = "./logs/segmentation"

# model, criterion, optimizer
optimizer = torch.optim.Adam([
    {'params': model.decoder.parameters(), 'lr': 1e-2}, 
    {'params': model.encoder.parameters(), 'lr': 1e-3},  
])

scheduler = ReduceLROnPlateau(optimizer, factor=0.15, patience=2)
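# The four label masks can overlap, so each output channel is treated as an independent
# binary problem; BCEWithLogitsLoss applies the sigmoid itself (the model has activation=None)
criterion = nn.BCEWithLogitsLoss()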
runner = SupervisedRunner()
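Because the masks can overlap (a pixel may carry more than one label), each of the four output channels is better treated as an independent yes/no decision. nn.CrossEntropyLoss applies a softmax across channels, which assumes exactly one class per pixel, and in PyTorch 1.4 (the version installed here) it expects integer class-index targets rather than four-channel float masks. The sketch below, with hypothetical helper names and made-up logits, contrasts the two views for a single pixel using the numerically stable binary cross-entropy on logits, the quantity nn.BCEWithLogitsLoss computes.

```python
import math


def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy on a raw logit:
    max(x, 0) - x*y + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def multilabel_pixel_loss(logits, targets):
    """Average BCE over the label channels of one pixel. Each channel is an
    independent decision, so several labels can be 'on' at once."""
    return sum(bce_with_logits(x, y) for x, y in zip(logits, targets)) / len(logits)


def softmax(logits):
    """Softmax with max-subtraction for numerical stability."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical pixel covered by two overlapping masks (labels 0 and 2)
targets = [1.0, 0.0, 1.0, 0.0]
logits = [4.0, -4.0, 4.0, -4.0]

sigmoid_probs = [1 / (1 + math.exp(-x)) for x in logits]  # both 'on' labels near 1
softmax_probs = softmax(logits)                            # forced to share 1.0
loss = multilabel_pixel_loss(logits, targets)
```

With sigmoids both active labels can reach about 0.98, while the softmax caps each of them just under 0.5 because the probabilities must sum to one.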

Model Training

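## Training fails on the first batch: the traceback below ends in albu.GridDistortion inside the training augmentations, most likely an incompatibility caused by API changes (deprecations) between library versions.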
In [159]:
runner = SupervisedRunner()
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    callbacks=[DiceCallback(), EarlyStoppingCallback(patience=5, min_delta=0.001)],
    logdir=logdir,
    num_epochs=num_epochs,
    verbose=True
)
1/19 * Epoch (train):   0% 0/312 [00:00<?, ?it/s]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-159-20d5dcd769a5> in <module>
      9     logdir=logdir,
     10     num_epochs=num_epochs,
---> 11     verbose=True
     12 )

c:\users\shane\lib\site-packages\catalyst\dl\runner\supervised.py in train(self, model, criterion, optimizer, loaders, logdir, callbacks, scheduler, resume, num_epochs, valid_loader, main_metric, minimize_metric, verbose, state_kwargs, checkpoint_data, fp16, monitoring_params, check)
    204             monitoring_params=monitoring_params
    205         )
--> 206         self.run_experiment(experiment, check=check)
    207 
    208     def infer(

c:\users\shane\lib\site-packages\catalyst\core\runner.py in run_experiment(self, experiment, check)
    380             else:
    381                 self.state.exception = ex
--> 382                 self._run_event("exception", moment=None)
    383 
    384         return self

c:\users\shane\lib\site-packages\catalyst\core\runner.py in _run_event(self, event, moment)
    229                 (moment == "end" or moment is None):  # for on_exception case
    230             for logger in self.loggers.values():
--> 231                 getattr(logger, fn_name)(self.state)
    232 
    233         if self.state is not None:

c:\users\shane\lib\site-packages\catalyst\dl\callbacks\misc.py in on_exception(self, state)
    150 
    151         if state.need_reraise_exception:
--> 152             raise exception
    153 
    154 

c:\users\shane\lib\site-packages\catalyst\core\runner.py in run_experiment(self, experiment, check)
    372         try:
    373             for stage in self.experiment.stages:
--> 374                 self._run_stage(stage)
    375         except (Exception, KeyboardInterrupt) as ex:
    376             # if an exception had been raised

c:\users\shane\lib\site-packages\catalyst\core\runner.py in _run_stage(self, stage)
    341 
    342             self._run_event("epoch", moment="start")
--> 343             self._run_epoch(stage=stage, epoch=epoch)
    344             self._run_event("epoch", moment="end")
    345 

c:\users\shane\lib\site-packages\catalyst\core\runner.py in _run_epoch(self, stage, epoch)
    330             self._run_event("loader", moment="start")
    331             with torch.set_grad_enabled(self.state.need_backward):
--> 332                 self._run_loader(loader)
    333             self._run_event("loader", moment="end")
    334 

c:\users\shane\lib\site-packages\catalyst\core\runner.py in _run_loader(self, loader)
    290         self.state.timer.start("_timers/data_time")
    291 
--> 292         for i, batch in enumerate(loader):
    293             self._run_batch(batch)
    294 

c:\users\shane\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
    343 
    344     def __next__(self):
--> 345         data = self._next_data()
    346         self._num_yielded += 1
    347         if self._dataset_kind == _DatasetKind.Iterable and \

c:\users\shane\lib\site-packages\torch\utils\data\dataloader.py in _next_data(self)
    383     def _next_data(self):
    384         index = self._next_index()  # may raise StopIteration
--> 385         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    386         if self._pin_memory:
    387             data = _utils.pin_memory.pin_memory(data)

c:\users\shane\lib\site-packages\torch\utils\data\_utils\fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

c:\users\shane\lib\site-packages\torch\utils\data\_utils\fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

<ipython-input-71-29d9038310f8> in __getitem__(self, idx)
     18         img = cv2.imread(image_path)
     19         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
---> 20         augmented = self.transforms(image=img, mask=mask)
     21         img = augmented['image']
     22         mask = augmented['mask']

c:\users\shane\lib\site-packages\albumentations\core\composition.py in __call__(self, force_apply, **data)
    174                     p.preprocess(data)
    175 
--> 176             data = t(force_apply=force_apply, **data)
    177 
    178             if dual_start_end is not None and idx == dual_start_end[1]:

c:\users\shane\lib\site-packages\albumentations\core\transforms_interface.py in __call__(self, force_apply, **kwargs)
     85                     )
     86                 kwargs[self.save_key][id(self)] = deepcopy(params)
---> 87             return self.apply_with_params(params, **kwargs)
     88 
     89         return kwargs

c:\users\shane\lib\site-packages\albumentations\core\transforms_interface.py in apply_with_params(self, params, force_apply, **kwargs)
     98                 target_function = self._get_target_function(key)
     99                 target_dependencies = {k: kwargs[k] for k in self.target_dependence.get(key, [])}
--> 100                 res[key] = target_function(arg, **dict(params, **target_dependencies))
    101             else:
    102                 res[key] = None

c:\users\shane\lib\site-packages\albumentations\augmentations\transforms.py in apply(self, img, stepsx, stepsy, interpolation, **params)
   1218 
   1219     def apply(self, img, stepsx=[], stepsy=[], interpolation=cv2.INTER_LINEAR, **params):
-> 1220         return F.grid_distortion(img, self.num_steps, stepsx, stepsy, interpolation, self.border_mode, self.value)
   1221 
   1222     def apply_to_mask(self, img, stepsx=[], stepsy=[], **params):

c:\users\shane\lib\site-packages\albumentations\augmentations\functional.py in wrapped_function(img, *args, **kwargs)
     52     def wrapped_function(img, *args, **kwargs):
     53         shape = img.shape
---> 54         result = func(img, *args, **kwargs)
     55         result = result.reshape(shape)
     56         return result

c:\users\shane\lib\site-packages\albumentations\augmentations\functional.py in grid_distortion(img, num_steps, xsteps, ysteps, interpolation, border_mode, value)
   1079             cur = prev + y_step * ysteps[idx]
   1080 
-> 1081         yy[start:end] = np.linspace(prev, cur, end - start)
   1082         prev = cur
   1083 

<__array_function__ internals> in linspace(*args, **kwargs)

c:\users\shane\lib\site-packages\numpy\core\function_base.py in linspace(start, stop, num, endpoint, retstep, dtype, axis)
    122 
    123     if num < 0:
--> 124         raise ValueError("Number of samples, %s, must be non-negative." % num)
    125     div = (num - 1) if endpoint else num
    126 

ValueError: Number of samples, -175, must be non-negative.
In [88]:
utils.plot_metrics(
    logdir=logdir, 
    # specify which metrics we want to plot
    metrics=["loss", "dice", 'lr', '_base/lr']
)
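The dice values logged by DiceCallback and plotted above measure mask overlap: twice the intersection divided by the combined size of both masks, so 1.0 is a perfect match and 0.0 means no overlap. A minimal sketch on flat 0/1 lists follows; dice and binarize are hypothetical helpers, and the sigmoid-then-threshold step and small eps are assumptions about typical usage, not necessarily Catalyst's exact settings.

```python
import math


def dice(pred, target, eps=1e-7):
    """Dice coefficient 2*|A & B| / (|A| + |B|) for two flat binary masks.
    eps avoids division by zero and scores two empty masks as 1.0."""
    intersection = sum(p * t for p, t in zip(pred, target))
    return (2.0 * intersection + eps) / (sum(pred) + sum(target) + eps)


def binarize(logits, threshold=0.5):
    """Turn raw logits into a 0/1 mask via sigmoid + threshold."""
    return [1 if 1 / (1 + math.exp(-x)) > threshold else 0 for x in logits]


pred = binarize([3.1, 2.2, -1.5, -4.0])   # -> [1, 1, 0, 0]
truth = [1, 0, 0, 0]
score = dice(pred, truth)                 # about 0.667: one shared pixel out of 2 + 1
```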
In [160]:
encoded_pixels = []
loaders = {"infer": valid_loader}
runner.infer(
    model=model,
    loaders=loaders,
    callbacks=[
        CheckpointCallback(
            resume=f"{logdir}/checkpoints/best.pth"),
        InferCallback()
    ],
)
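valid_masks = []
# One (350, 525) slot per image-label pair: 4 label channels per validation image
probabilities = np.zeros((len(valid_dataset) * 4, 350, 525), dtype=np.float32)
for i, (batch, output) in enumerate(tqdm.tqdm(zip(
        valid_dataset, runner.callbacks[0].predictions["logits"]))):
    image, mask = batch
    # Bring ground-truth masks back to 350x525; cv2.resize takes dsize as (width, height)
    for m in mask:
        if m.shape != (350, 525):
            m = cv2.resize(m, dsize=(525, 350), interpolation=cv2.INTER_LINEAR)
        valid_masks.append(m)

    # "logits" are raw outputs (activation=None); apply a sigmoid before thresholding
    for j, probability in enumerate(output):
        if probability.shape != (350, 525):
            probability = cv2.resize(probability, dsize=(525, 350), interpolation=cv2.INTER_LINEAR)
        probabilities[i * 4 + j, :, :] = probability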
---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
<ipython-input-160-f2477811b7df> in <module>
      7         CheckpointCallback(
      8             resume=f"{logdir}/checkpoints/best.pth"),
----> 9         InferCallback()
     10     ],
     11 )

c:\users\shane\lib\site-packages\catalyst\dl\runner\supervised.py in infer(self, model, loaders, callbacks, verbose, state_kwargs, fp16, check)
    248             distributed_params=fp16
    249         )
--> 250         self.run_experiment(experiment, check=check)
    251 
    252     def predict_loader(

c:\users\shane\lib\site-packages\catalyst\core\runner.py in run_experiment(self, experiment, check)
    380             else:
    381                 self.state.exception = ex
--> 382                 self._run_event("exception", moment=None)
    383 
    384         return self

c:\users\shane\lib\site-packages\catalyst\core\runner.py in _run_event(self, event, moment)
    229                 (moment == "end" or moment is None):  # for on_exception case
    230             for logger in self.loggers.values():
--> 231                 getattr(logger, fn_name)(self.state)
    232 
    233         if self.state is not None:

c:\users\shane\lib\site-packages\catalyst\dl\callbacks\misc.py in on_exception(self, state)
    150 
    151         if state.need_reraise_exception:
--> 152             raise exception
    153 
    154 

c:\users\shane\lib\site-packages\catalyst\core\runner.py in run_experiment(self, experiment, check)
    372         try:
    373             for stage in self.experiment.stages:
--> 374                 self._run_stage(stage)
    375         except (Exception, KeyboardInterrupt) as ex:
    376             # if an exception had been raised

c:\users\shane\lib\site-packages\catalyst\core\runner.py in _run_stage(self, stage)
    336         self._prepare_for_stage(stage)
    337 
--> 338         self._run_event("stage", moment="start")
    339         for epoch in range(self.state.num_epochs):
    340             self.state.stage_epoch = epoch

c:\users\shane\lib\site-packages\catalyst\core\runner.py in _run_event(self, event, moment)
    223         if self.callbacks is not None:
    224             for callback in self.callbacks.values():
--> 225                 getattr(callback, fn_name)(self.state)
    226 
    227         # after callbacks

c:\users\shane\lib\site-packages\catalyst\core\callbacks\checkpoint.py in on_stage_start(self, state)
    212 
    213         if self.resume is not None:
--> 214             self.load_checkpoint(filename=self.resume, state=state)
    215 
    216     def on_epoch_end(self, state: _State):

c:\users\shane\lib\site-packages\catalyst\core\callbacks\checkpoint.py in load_checkpoint(filename, state)
    125             )
    126         else:
--> 127             raise Exception(f"No checkpoint found at {filename}")
    128 
    129     def get_metric(self, last_valid_metrics) -> Dict:

Exception: No checkpoint found at ./logs/segmentation/checkpoints/best.pth
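Inference fails because ./logs/segmentation/checkpoints/best.pth does not exist: the training run crashed on its first batch, so no checkpoint was saved. A small guard like the one below makes the cell report that clearly instead of raising deep inside CheckpointCallback; find_checkpoint is a hypothetical helper that only checks the path the original cell uses.

```python
from pathlib import Path


def find_checkpoint(logdir, name="best.pth"):
    """Return the checkpoint path if training saved one, else None."""
    candidate = Path(logdir) / "checkpoints" / name
    return candidate if candidate.is_file() else None


ckpt = find_checkpoint("./logs/segmentation")
if ckpt is None:
    print("No checkpoint found: run training successfully before inference.")
else:
    print(f"Running inference from {ckpt}")
```

Only when ckpt is not None would the runner.infer(...) call with CheckpointCallback(resume=...) be worth running.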
In [99]:
torch.cuda.is_available  # missing (): displays the function object instead of calling it
Out[99]:
<function torch.cuda.is_available()>
In [100]:
x.cuda()      
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-100-4da82bb84806> in <module>
----> 1 x.cuda()

NameError: name 'x' is not defined
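Calling torch.cuda.is_available() with parentheses returns a bool, which can drive device selection instead of the hard-coded DEVICE = 'cuda'. A minimal sketch follows; pick_device is a hypothetical helper that also falls back to CPU when PyTorch is not installed.

```python
def pick_device() -> str:
    """Return 'cuda' when PyTorch can see a GPU, otherwise 'cpu'."""
    try:
        import torch
    except ImportError:  # PyTorch not installed: fall back to CPU
        return "cpu"
    # The call returns True/False; without () we would only get the function object
    return "cuda" if torch.cuda.is_available() else "cpu"


DEVICE = pick_device()
print(DEVICE)
```

The result can then be passed to model.to(DEVICE) or tensor.to(DEVICE), which avoids NameErrors from ad-hoc variables like x and keeps the notebook runnable on machines without a GPU.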